#!/usr/bin/env Rscript
# This script re-builds the trained LinkOrgs model and generates the visualization figures (Figures 3 and 4).
{
  # First, re-initialize the target model: download the model code, trained
  # weights and character-indicator table, rebuild the network from the
  # downloaded R sources, then load the trained weights into it.
  {
    setwd('~/Dropbox/Directory');  # set to location of replication files
    # NOTE(review): not used in this file; presumably read by the sourced
    # JaxTransformer_*.R scripts below -- confirm.
    conda_env <- "LinkOrgsEnv"; backend <- "METAL"
    # allow at least 5 minutes per download before download.file() times out
    options(timeout = max(60*5, getOption("timeout")))

    # local paths for the downloaded artifacts
    DownloadFolder <- "./DataOutputs"
    ml_version <- "v4"
    ModelZipLoc <- sprintf('%s/Model_%s.zip', DownloadFolder, ml_version)
    ModelLoc <- sub("\\.zip$", "", ModelZipLoc)  # unzip target, e.g. ./DataOutputs/Model_v4
    WeightsLoc <- sprintf('%s/ModelWeights_%s.eqx', DownloadFolder, ml_version)
    CharIndicatorsLoc <- sprintf('%s/CharIndicatorsLoc.csv', DownloadFolder)
    # download.file() fails if the destination directory does not exist
    dir.create(DownloadFolder, showWarnings = FALSE, recursive = TRUE)

    # load in model
    {
      ModelURL <- "https://www.dropbox.com/scl/fi/jt8t101x9ar1s5zgxjijo/Analysis.zip?rlkey=dmtyu8hcjeqcbvvxw1mfg22y2&dl=0"
      WeightsURL <- "https://www.dropbox.com/scl/fi/zr4bziggj3nugrpovkxrm/LinkOrgsBase_17PT3M_2024-02-29_ilast.eqx?rlkey=b6f7i8dhuro62hlszm365vofi&dl=0"

      # convert the Dropbox share links into direct-download links
      ModelURL <- LinkOrgs::dropboxURL2downloadURL(ModelURL)
      WeightsURL <- LinkOrgs::dropboxURL2downloadURL(WeightsURL)

      # Both downloads are binary files. mode = "wb" is required on Windows:
      # download.file() only picks binary mode automatically when the URL ends
      # in ".zip" etc., and these URLs end in a query string, so the default
      # text mode would corrupt the files through CRLF translation.

      # download weights
      download.file(WeightsURL, destfile = WeightsLoc, mode = "wb")

      # download and unzip model
      download.file(ModelURL, destfile = ModelZipLoc, mode = "wb")
      unzip(ModelZipLoc, exdir = ModelLoc)

      # download characters & save
      charIndicators <- LinkOrgs::url2dt("https://www.dropbox.com/scl/fi/1jh8nrwsucfzj2gy9rydy/charIndicators.csv.zip?rlkey=wkhqk9x3550l364xbvnvnkoem&dl=0")
      data.table::fwrite(charIndicators, file = CharIndicatorsLoc)
    }

    # build model (trainModel = FALSE: rebuild the architecture only, no training)
    print("Re-building ML model...")
    trainModel <- FALSE; AnalysisName <- "LinkOrgs"
    charIndicators <- as.matrix(data.table::fread(file = CharIndicatorsLoc))
    # local = TRUE evaluates each file in the calling environment (the global
    # environment here), so the objects they define stay visible to this script
    source(sprintf('%s/Analysis/LinkOrgs_Helpers.R', ModelLoc), local = TRUE)
    source(sprintf('%s/Analysis/JaxTransformer_Imports.R', ModelLoc), local = TRUE)
    source(sprintf('%s/Analysis/JaxTransformer_BuildML.R', ModelLoc), local = TRUE)
    source(sprintf('%s/Analysis/JaxTransformer_TrainDefine.R', ModelLoc), local = TRUE)

    # obtain trained weights
    # NOTE(review): `eq`, ModelList, StateList and opt_state are defined by the
    # sourced scripts; the list is used as the template structure the weights
    # are deserialised into. Only the first two elements are kept.
    print("Applying trained weights...")
    ModelList <- eq$tree_deserialise_leaves(WeightsLoc, list(ModelList, StateList, opt_state))
    StateList <- ModelList[[2]]; ModelList <- ModelList[[1]]
  }
  
  # Next, score the held-out (out-of-sample) positive and negative name pairs
  # with both the string-distance baseline and the trained model.
  {
    NegMatches_mat_hold <- as.data.frame(data.table::fread("./DataOutputs/NegMatches_mat_hold.csv"))
    PosMatches_mat_hold <- as.data.frame(data.table::fread("./DataOutputs/PosMatches_mat_hold.csv"))

    # Baseline: Jaccard distance between the 2-gram sets of the two names.
    # `q` is only honoured by the q-gram methods ("qgram", "jaccard", "cosine");
    # with stringdist()'s default method ("osa") it is silently ignored and an
    # edit distance is returned, which would not match the "Jaccard Distance"
    # axis label of the baseline figure.
    stringdist_pos <- stringdist::stringdist(a = unlist(PosMatches_mat_hold[, "name1"]),
                                             b = unlist(PosMatches_mat_hold[, "name2"]),
                                             method = "jaccard", q = 2)
    stringdist_neg <- stringdist::stringdist(a = unlist(NegMatches_mat_hold[, "name1"]),
                                             b = unlist(NegMatches_mat_hold[, "name2"]),
                                             method = "jaccard", q = 2)

    # Model: predicted match probability for every pair, scored in batches
    nBatch_BigBatch <- 50L
    PosProbMat_out <- GetMatchProb_BigBatch(name1 = unlist(PosMatches_mat_hold[, "name1"]),
                                            name2 = unlist(PosMatches_mat_hold[, "name2"]),
                                            nBatch_BigBatch = nBatch_BigBatch)
    NegProbMat_out <- GetMatchProb_BigBatch(name1 = unlist(NegMatches_mat_hold[, "name1"]),
                                            name2 = unlist(NegMatches_mat_hold[, "name2"]),
                                            nBatch_BigBatch = nBatch_BigBatch)

    # Separation between matches and non-matches: two-sample KS statistic for
    # the baseline and for the model, plus the probability summaries.
    # print() is required: under Rscript only the value of each complete
    # top-level expression is auto-printed, so bare calls inside this `{ }`
    # block were computed and then discarded.
    print(ks.test(stringdist_pos, stringdist_neg)$statistic)
    print(ks.test(c(PosProbMat_out$matchprob), c(NegProbMat_out$matchprob))$statistic)
    print(summary(NegProbMat_out$matchprob)); print(summary(PosProbMat_out$matchprob))

    # high prob matches are those that share last character of last word
    # View(NegProbMat_out[NegProbMat_out$matchprob>0.45,])
    # View(PosProbMat_out[PosProbMat_out$matchprob>0.45,])

    # binds (assumes GetMatchProb_BigBatch returns rows in input order)
    PosMatches_mat_hold <- cbind(PosMatches_mat_hold, "matchProb_est" = PosProbMat_out$matchprob)
    # NOTE(review): unlike the positive set, the estimate is bound onto
    # NegProbMat_out itself (duplicating its `matchprob` column) rather than
    # onto NegMatches_mat_hold; kept as-is because the commented View() calls
    # below read `matchProb_est` from NegProbMat_out.
    NegProbMat_out <- cbind(NegProbMat_out, "matchProb_est" = NegProbMat_out$matchprob)

    #View( head( PosMatches_mat_hold[order(PosProbMat_out$matchprob),c("name1",'name2','matchProb_est')], 100) )
    #View( tail( PosMatches_mat_hold[order(PosProbMat_out$matchprob),c("name1",'name2','matchProb_est')], 300) )
    #View( head( NegProbMat_out[order(NegProbMat_out$matchprob),c("name1",'name2','matchProb_est')], 20) )
    #View( tail( NegProbMat_out[order(NegProbMat_out$matchprob),c("name1",'name2','matchProb_est')], 20) )

    # mean predicted probability of the 20 highest-scoring non-matches
    MeanTop20Neg_prob <- mean(tail(NegProbMat_out[order(NegProbMat_out$matchprob), c("name1", 'name2', 'matchprob')], 20)[[3]])
    # true matches whose score lies within 0.02 of that value (for inspection)
    tmp <- PosProbMat_out[abs(PosProbMat_out$matchprob - MeanTop20Neg_prob) < 0.02, ]
  }
  
  # finally, make figures 
  { 
    # Figure 4: how well the string-distance baseline (left panel) and the
    # model's predicted match probabilities (right panel) separate matched
    # from non-matched name pairs.
    #
    # plot_contrast_density() writes one panel to the PDF `file`:
    #   pos, neg          numeric scores for matches (drawn black) and
    #                     non-matches (drawn gray); each kernel density is
    #                     rescaled so that its peak equals 1
    #   main, xlab        plot title and x-axis label
    #   neg_label_offset  height of the "Non-matches" label above its peak,
    #                     as a fraction of the match-density peak
    # The device is closed via on.exit(), so an error while plotting cannot
    # leave a half-written PDF device open. Returns `file` invisibly.
    plot_contrast_density <- function(file, pos, neg, main, xlab, neg_label_offset) {
      pdf(file)
      on.exit(dev.off(), add = TRUE)
      dp_ <- density(pos)
      dn_ <- density(neg)
      dn_$y <- dn_$y / max(dn_$y)
      dp_$y <- dp_$y / max(dp_$y)
      par(mar = c(4, 5, 5, 1))
      plot(dn_,
           lwd = 2,
           xlim = range(dn_$x, dp_$x),
           ylim = c(0, max(dn_$y, dp_$y) * 1.2), col = "gray",
           main = main, cex.lab = 2,
           xlab = xlab, cex.main = 2)
      lines(dp_, lwd = 2)
      text(dp_$x[which.max(dp_$y)], max(dp_$y) + max(dp_$y) * 0.05,
           labels = "Matches", cex = 1.5)
      text(dn_$x[which.max(dn_$y)], max(dn_$y) + max(dp_$y) * neg_label_offset,
           labels = "Non-matches", cex = 1.5, col = "gray")
      invisible(file)
    }

    # pdf() fails if the output directory does not exist
    dir.create("./Figures", showWarnings = FALSE, recursive = TRUE)

    # contrast plot - Figure 4, left panel
    plot_contrast_density("./Figures/TransformerVectorCors_baseline.pdf",
                          pos = stringdist_pos, neg = stringdist_neg,
                          main = "Fuzzy Matching Baseline",
                          xlab = "Jaccard Distance",
                          neg_label_offset = 0.1)

    # contrast plot - Figure 4, right panel
    plot_contrast_density("./Figures/TransformerVectorCors.pdf",
                          pos = PosProbMat_out$matchprob, neg = NegProbMat_out$matchprob,
                          main = "Predicted Match Probabilities",
                          xlab = "Predicted Probability",
                          neg_label_offset = 0.05)

    # case study plot - Figure 3
    # Embed four name pairs (n1_vec[i], n2_vec[i]) with the model, project the
    # embeddings onto their first two principal components, and draw both
    # members of a pair with the same plotting symbol: n1 names are labelled
    # just above their point, n2 names just below.
    dir.create("./Figures", showWarnings = FALSE, recursive = TRUE)
    pdf("./Figures/PCFigDistances.pdf")
    {
      n1_vec <- tolower(enc2utf8(c("oracle corporation", "Chase", "WaMu", "Goldman")))
      n2_vec <- tolower(enc2utf8(c("oracle", "JP Morgan Chase", "Washington Mutual", "Goldman Sachs Group, Inc.")))
      n_pairs <- length(n1_vec)

      # one representation row per name: rows 1..n_pairs hold n1_vec, the
      # remaining rows n2_vec (names are already lower-cased above)
      myReps <- GetAliasRep_BigBatch(c(n1_vec, n2_vec))
      prcomp_ <- prcomp(myReps, scale. = TRUE, center = TRUE)
      prcomp_allPredicts <- predict(prcomp_, myReps)

      # Scores of the fitting data are centred at 0, so scaling the observed
      # range by 1.5 widens both sides of each axis and leaves room for labels.
      ylim_ <- range(prcomp_allPredicts[, 2]) * 1.5
      par(mar = c(5, 5, 3, 1))
      plot(prcomp_allPredicts,  # matrix: first two columns (PC1, PC2) are plotted
           xlim = range(prcomp_allPredicts[, 1]) * 1.5,
           ylim = ylim_,
           cex = 0, xlab = "Principal Component 1", ylab = "Principal Component 2", cex.lab = 2)

      # PC1/PC2 coordinates of each pair's two members
      xy1 <- prcomp_allPredicts[seq_len(n_pairs), 1:2]
      xy2 <- prcomp_allPredicts[-seq_len(n_pairs), 1:2]
      # squared within-pair distance in PC1/PC2 space (not drawn; kept for inspection)
      d12 <- rowSums((xy1 - xy2)^2)
      # every pair is drawn in black (the earlier distance-based gray shading
      # was always overwritten with "black")
      col_vec <- rep("black", n_pairs)
      points(xy1, pch = seq_len(n_pairs), cex = 1.5, col = col_vec)
      points(xy2, pch = seq_len(n_pairs), cex = 1.5, col = col_vec)

      ep_ <- diff(ylim_) * 0.01  # vertical label offset
      cex_ <- 1
      # font = 1 (plain): valid font faces are 1-5; the previous font = 165
      # was replaced by font 1 by the pdf device anyway, with a warning per call
      text(xy1[, 1], xy1[, 2] + ep_,
           labels = n1_vec, cex = cex_, font = 1, col = col_vec)
      text(xy2[, 1], xy2[, 2] - ep_,
           labels = n2_vec, cex = cex_, font = 1, col = col_vec)
    }
    dev.off()
  }
}


